Generative Adversarial Networks (GANs) are one of the most promising recent developments in deep learning. Introduced by Ian Goodfellow in 2014, GANs attack the problem of unsupervised learning by training two deep networks, called the Generator and the Discriminator, that compete and cooperate with each other. Over the course of training, both networks eventually learn to perform their tasks.
GANs are almost always explained through the analogy of a counterfeiter (Generator) and the police (Discriminator). Initially, the counterfeiter shows the police a fake banknote. The police say it is fake and give the counterfeiter feedback on why. The counterfeiter makes a new fake based on that feedback. The police say it is still fake and offer a new round of feedback. The counterfeiter makes another fake based on the latest feedback. The cycle repeats until the police are fooled by a fake that looks real.
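Formally, the original GAN paper (Goodfellow et al., 2014) writes this game as a minimax objective, where D(x) is the Discriminator's estimated probability that x is real and G(z) is the image the Generator produces from noise z:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
```

In practice, the same paper suggests training G to maximize log D(G(z)) instead of minimizing log(1 - D(G(z))), because it gives stronger gradients early in training. Labeling generated images as "real" when updating the Generator, as the training code in this notebook does, corresponds to that choice.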
While the idea of a GAN is simple in theory, it is very difficult to build a model that works. A GAN couples two deep networks, and the Generator learns only from gradients backpropagated through the Discriminator, which makes training hard to stabilize. Deep Convolutional GAN (DCGAN) is one of the models that demonstrated how to build a practical GAN that learns by itself to synthesize new images. In this article, we discuss how a working DCGAN can be built using Keras 2.0 on a TensorFlow 1.0 backend in fewer than 200 lines of code. We will train a DCGAN to learn to write handwritten digits, the MNIST way. Note that the notebook cells below are written against the older Keras 1 API (Convolution2D, border_mode, nb_epoch, BatchNormalization(mode=2)) with channels-first image tensors.
In [1]:
%matplotlib inline
import random
import numpy as np
import matplotlib.pyplot as plt
from IPython import display
from tqdm import tqdm
from keras import backend as K
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input
from keras.layers.core import Reshape, Dense, Dropout, Activation, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Convolution2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.optimizers import Adam
# The cells below use the Keras 1 API with channels-first image tensors (N, 1, 28, 28)
K.set_image_dim_ordering('th')
In [2]:
img_rows, img_cols = 28, 28
# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
print(np.min(X_train), np.max(X_train))
print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')
In [3]:
def make_trainable(net, val):
    """Set the trainable flag on a model and on every one of its layers."""
    net.trainable = val
    for l in net.layers:
        l.trainable = val
The discriminator, which tells how real an image is, is essentially a deep Convolutional Neural Network (CNN), as shown in Figure 1. For the MNIST dataset, the input is an image (28 pixels x 28 pixels x 1 channel). The sigmoid output is a scalar probability of how real the image is (0.0 is certainly fake, 1.0 is certainly real, and anything in between is a gray area); the cell below uses an equivalent two-unit softmax over the classes [fake, real]. Unlike a typical CNN, there is no max-pooling between layers; strided convolutions perform the downsampling instead. The activation function after each convolutional layer is a leaky ReLU. Dropout of 0.4 to 0.7 between layers helps prevent overfitting and memorization (the cell below uses dropout_rate = 0.25). Listing 1 shows the implementation in Keras.
Figure 1. The DCGAN discriminator tells how real an input image of a digit is. The MNIST dataset is used as ground truth for real images. Strided convolution, instead of max-pooling, downsamples the image.
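To make the downsampling concrete, here is a minimal NumPy sketch (not the Keras implementation) of a single-channel, stride-2 convolution followed by a leaky ReLU with slope 0.2. It assumes symmetric zero padding of kernel_size // 2, which for a 5x5 kernel gives the same output sizes as 'same' padding: 28x28 becomes 14x14, then 7x7, matching the two strided layers of the discriminator in the cell below. The function names conv2d_strided and leaky_relu are illustrative only, and the real layers have 256 and 512 channels rather than one.

```python
import numpy as np

def conv2d_strided(image, kernel, stride=2):
    """Naive single-channel 2D cross-correlation with symmetric zero padding."""
    k = kernel.shape[0]
    pad = k // 2
    padded = np.pad(image, pad, mode="constant")
    out_h = (padded.shape[0] - k) // stride + 1
    out_w = (padded.shape[1] - k) // stride + 1
    out = np.zeros((out_h, out_w), dtype=image.dtype)
    for i in range(out_h):
        for j in range(out_w):
            # Each output pixel sees a k x k patch; the stride skips every other position
            patch = padded[i * stride:i * stride + k, j * stride:j * stride + k]
            out[i, j] = np.sum(patch * kernel)
    return out

def leaky_relu(x, alpha=0.2):
    """Pass positive values through; scale negative values by alpha."""
    return np.where(x > 0, x, alpha * x)

rng = np.random.default_rng(0)
img = rng.random((28, 28)).astype(np.float32)               # stand-in for one MNIST digit
kernel = rng.standard_normal((5, 5)).astype(np.float32)     # stand-in for a learned 5x5 filter

h1 = leaky_relu(conv2d_strided(img, kernel))   # first strided layer
h2 = leaky_relu(conv2d_strided(h1, kernel))    # second strided layer
print(h1.shape, h2.shape)  # (14, 14) (7, 7)
```

Because the stride does the downsampling, the network learns how to reduce resolution instead of relying on a fixed max-pooling rule.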
The generator synthesizes fake images. In Figure 2, the fake image is generated from 100-dimensional noise (uniformly distributed between -1.0 and 1.0) using transposed convolution, which is often loosely called deconvolution. Instead of the fractionally-strided convolution suggested in DCGAN, upsampling between the first three layers is used, since it synthesizes more realistic handwriting images. Between layers, batch normalization stabilizes learning. The activation function after each layer is a ReLU, and a sigmoid at the last layer produces the fake image. Dropout of 0.3 to 0.5 at the first layer prevents overfitting. Listing 2 shows the implementation in Keras. The generator in the cell below is a smaller variant of this design: it samples noise uniformly from [0, 1), upsamples once (from 14x14 to 28x28), and uses no dropout.
Figure 2. The generator model synthesizes fake MNIST images from noise. Upsampling is used instead of fractionally-strided (transposed) convolution.
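The generator's upsampling step can be illustrated the same way. UpSampling2D with size=(2, 2) repeats each row and column twice (nearest-neighbour upsampling), so the (nch, 14, 14) tensor produced by the Dense and Reshape layers becomes (nch, 28, 28). In this NumPy sketch, upsample2x is a hypothetical helper, the random matrix W is a stand-in for the trained Dense weights, and batch normalization is omitted.

```python
import numpy as np

def upsample2x(x):
    """Nearest-neighbour 2x upsampling of a channels-first (C, H, W) array."""
    return np.repeat(np.repeat(x, 2, axis=1), 2, axis=2)

nch = 200
rng = np.random.default_rng(0)
z = rng.uniform(0.0, 1.0, size=100)                      # noise vector, sampled as in the notebook
W = rng.standard_normal((100, nch * 14 * 14)) * 0.01     # hypothetical Dense weights
h = np.maximum(z @ W, 0.0).reshape(nch, 14, 14)          # Dense -> ReLU -> Reshape (no batch norm)
up = upsample2x(h)
print(h.shape, up.shape)  # (200, 14, 14) (200, 28, 28)

# On a tiny input, every value is copied into a 2x2 block
tiny = np.array([[[1, 2], [3, 4]]])   # shape (1, 2, 2)
print(upsample2x(tiny)[0])
# [[1 1 2 2]
#  [1 1 2 2]
#  [3 3 4 4]
#  [3 3 4 4]]
```

The upsampling itself has no weights; the 3x3 convolutions that follow it in the notebook learn to smooth the blocky result into strokes.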
In [4]:
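shp = X_train.shape[1:]
print(shp)
dropout_rate = 0.25
opt = Adam(lr=1e-3)   # optimizer for the stacked GAN model (generator updates)
dopt = Adam(lr=1e-4)  # optimizer for the standalone discriminator
nch = 200             # feature maps in the generator's first block
# Build Generative model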
g_input = Input(shape=[100])
H = Dense(nch*14*14, init='glorot_normal')(g_input)
H = BatchNormalization(mode=2)(H)
H = Activation('relu')(H)
H = Reshape( [nch, 14, 14] )(H)
H = UpSampling2D(size=(2, 2))(H)
H = Convolution2D(nch//2, 3, 3, border_mode='same', init='glorot_uniform')(H)  # // keeps the filter count an int
H = BatchNormalization(mode=2)(H)
H = Activation('relu')(H)
H = Convolution2D(nch//4, 3, 3, border_mode='same', init='glorot_uniform')(H)
H = BatchNormalization(mode=2)(H)
H = Activation('relu')(H)
H = Convolution2D(1, 1, 1, border_mode='same', init='glorot_uniform')(H)
g_V = Activation('sigmoid')(H)
generator = Model(g_input,g_V)
generator.compile(loss='binary_crossentropy', optimizer=opt)
generator.summary()
# Build Discriminative model
d_input = Input(shape=shp)
# No activation inside the convolutions: LeakyReLU is the only nonlinearity
# (a ReLU here would zero out negatives before LeakyReLU could scale them)
H = Convolution2D(256, 5, 5, subsample=(2, 2), border_mode='same')(d_input)
H = LeakyReLU(0.2)(H)
H = Dropout(dropout_rate)(H)
H = Convolution2D(512, 5, 5, subsample=(2, 2), border_mode='same')(H)
H = LeakyReLU(0.2)(H)
H = Dropout(dropout_rate)(H)
H = Flatten()(H)
H = Dense(256)(H)
H = LeakyReLU(0.2)(H)
H = Dropout(dropout_rate)(H)
d_V = Dense(2,activation='softmax')(H)
discriminator = Model(d_input,d_V)
discriminator.compile(loss='categorical_crossentropy', optimizer=dopt)
discriminator.summary()
# Freeze weights in the discriminator for stacked training
make_trainable(discriminator, False)
# Build stacked GAN model
gan_input = Input(shape=[100])
H = generator(gan_input)
gan_V = discriminator(H)
GAN = Model(gan_input, gan_V)
GAN.compile(loss='categorical_crossentropy', optimizer=opt)
GAN.summary()
In [5]:
def plot_loss(losses):
    display.clear_output(wait=True)
    display.display(plt.gcf())
    plt.figure(figsize=(10, 8))
    plt.plot(losses["d"], label='discriminative loss')
    plt.plot(losses["g"], label='generative loss')
    plt.legend()
    plt.show()
In [6]:
def plot_gen(n_ex=16, dim=(4, 4), figsize=(10, 10)):
    # Sample noise from the same distribution used during training
    noise = np.random.uniform(0, 1, size=[n_ex, 100])
    generated_images = generator.predict(noise)
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i + 1)
        img = generated_images[i, 0, :, :]  # channels-first: (N, 1, 28, 28)
        plt.imshow(img, cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.show()
In [7]:
ntrain = 10000
trainidx = random.sample(range(0,X_train.shape[0]), ntrain)
XT = X_train[trainidx,:,:,:]
# Pre-train the discriminator network ...
noise_gen = np.random.uniform(0,1,size=[XT.shape[0],100])
generated_images = generator.predict(noise_gen)
X = np.concatenate((XT, generated_images))
n = XT.shape[0]
# One-hot labels: column 1 = real (MNIST), column 0 = fake (generated)
y = np.zeros([2*n, 2])
y[:n, 1] = 1
y[n:, 0] = 1
make_trainable(discriminator,True)
discriminator.fit(X,y, nb_epoch=1, batch_size=32)
y_hat = discriminator.predict(X)
In [8]:
y_hat_idx = np.argmax(y_hat,axis=1)
y_idx = np.argmax(y,axis=1)
diff = y_idx-y_hat_idx
n_tot = y.shape[0]
n_rig = (diff==0).sum()
acc = n_rig*100.0/n_tot
print("Accuracy: %0.02f pct (%d of %d) right" % (acc, n_rig, n_tot))
In [9]:
# set up loss storage vector
losses = {"d":[], "g":[]}
The adversarial model is simply the generator and discriminator stacked together, as shown in Figure 3. The Generator tries to fool the Discriminator while learning from its feedback. Listing 4 shows the implementation in Keras. In the cell above, the stacked model GAN is compiled with opt (Adam, learning rate 1e-3), while the standalone discriminator uses dopt (Adam, learning rate 1e-4). The discriminator's weights are frozen with make_trainable(discriminator, False) before the stack is compiled, so that GAN.train_on_batch updates only the generator.
Figure 3. The adversarial model is simply the generator with its output connected to the input of the discriminator. Also shown is the training process, in which the Generator labels its fake images as real (1.0) to try to fool the Discriminator.
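The label flip in Figure 3 is easiest to see in the loss itself. This NumPy sketch re-implements categorical cross-entropy (the loss used to compile both the discriminator and the stacked GAN model) and applies it to hypothetical discriminator outputs for four generated images. With the "fake" labels used to train the Discriminator, confident rejections give a low loss; with the "real" labels (y2 in the training loop) used to train the Generator through the frozen Discriminator, the same outputs give a high loss, which the Generator can only reduce by making its images look real.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean categorical cross-entropy over a batch (re-implemented for illustration)."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

# Hypothetical discriminator outputs for 4 generated images; columns are [fake, real]
d_out = np.array([[0.9, 0.1],
                  [0.8, 0.2],
                  [0.6, 0.4],
                  [0.3, 0.7]])

y_fake = np.zeros((4, 2))
y_fake[:, 0] = 1  # labels when training the Discriminator on generated images
y_real = np.zeros((4, 2))
y_real[:, 1] = 1  # labels when training the Generator through the frozen Discriminator

d_loss = categorical_crossentropy(y_fake, d_out)  # about 0.51
g_loss = categorical_crossentropy(y_real, d_out)  # about 1.30
print(round(d_loss, 3), round(g_loss, 3))
```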
Training is the hardest part. We first check that the Discriminator model works by training it alone on real and fake images. Afterwards, the Discriminator and Adversarial models are trained in alternation. Figure 4 shows the Discriminator model and Figure 3 shows the Adversarial model during training. Listing 5 shows the training code in Keras.
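The discriminator-only check can be sketched in NumPy as well. The hypothetical helper make_discriminator_batch mirrors how the notebook stacks real and generated images with one-hot labels (column 1 = real, column 0 = fake), and accuracy_pct mirrors the argmax comparison in cell [8]. The random arrays are stand-ins for MNIST images and generator output, and y_hat is a hypothetical discriminator that calls every image real.

```python
import numpy as np

def make_discriminator_batch(real, fake):
    """Stack real and generated images with one-hot labels: column 1 = real, column 0 = fake."""
    X = np.concatenate((real, fake))
    y = np.zeros((len(real) + len(fake), 2))
    y[:len(real), 1] = 1
    y[len(real):, 0] = 1
    return X, y

def accuracy_pct(y, y_hat):
    """Percentage of samples whose predicted class (argmax) matches the label."""
    return 100.0 * np.mean(np.argmax(y_hat, axis=1) == np.argmax(y, axis=1))

rng = np.random.default_rng(0)
real = rng.random((8, 1, 28, 28))   # stand-ins for MNIST images (channels-first)
fake = rng.random((8, 1, 28, 28))   # stand-ins for generator output
X, y = make_discriminator_batch(real, fake)

# A discriminator that always answers "real" gets exactly half of a balanced batch right
y_hat = np.tile([0.2, 0.8], (len(X), 1))
print(X.shape, y.shape, accuracy_pct(y, y_hat))  # (16, 1, 28, 28) (16, 2) 50.0
```

A balanced batch makes 50% the chance baseline, so a pre-trained discriminator should score well above it before adversarial training starts.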
In [10]:
def train_for_n(nb_epoch=5000, plt_frq=25, BATCH_SIZE=32):
    # Each "epoch" here is a single minibatch step, not a full pass over X_train
    for e in tqdm(range(nb_epoch)):
        # Make generative images
        image_batch = X_train[np.random.randint(0, X_train.shape[0], size=BATCH_SIZE), :, :, :]
        noise_gen = np.random.uniform(0, 1, size=[BATCH_SIZE, 100])
        generated_images = generator.predict(noise_gen)
        # Train discriminator on real (column 1) and generated (column 0) images
        X = np.concatenate((image_batch, generated_images))
        y = np.zeros([2*BATCH_SIZE, 2])
        y[0:BATCH_SIZE, 1] = 1
        y[BATCH_SIZE:, 0] = 1
        make_trainable(discriminator, True)
        d_loss = discriminator.train_on_batch(X, y)
        losses["d"].append(d_loss)
        # Train the generator-discriminator stack to map noise to the "real" class
        noise_tr = np.random.uniform(0, 1, size=[BATCH_SIZE, 100])
        y2 = np.zeros([BATCH_SIZE, 2])
        y2[:, 1] = 1
        make_trainable(discriminator, False)
        g_loss = GAN.train_on_batch(noise_tr, y2)
        losses["g"].append(g_loss)
        # Update plots
        if e % plt_frq == plt_frq - 1:
            plot_loss(losses)
            plot_gen()
In [11]:
train_for_n(nb_epoch=250, plt_frq=25,BATCH_SIZE=128)
In [ ]:
from keras import backend as K
# Lower both learning rates by 10x and continue training
K.set_value(opt.lr, 1e-4)
K.set_value(dopt.lr, 1e-5)
train_for_n(nb_epoch=100, plt_frq=10, BATCH_SIZE=128)
In [ ]:
K.set_value(opt.lr, 1e-5)
K.set_value(dopt.lr, 1e-6)
train_for_n(nb_epoch=100, plt_frq=10,BATCH_SIZE=256)